{
 "cells": [
  {
   "cell_type": "raw",
   "id": "c14da114-1a4a-487d-9cff-e0e8c30ba366",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_position: 3\n",
    "title: Querying a SQL DB\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "506e9636",
   "metadata": {},
   "source": [
    "We can replicate our SQLDatabaseChain with Runnables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a927516",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts import ChatPromptTemplate\n",
    "\n",
    "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query:\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3f51f386",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.utilities import SQLDatabase"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c3449d6-684b-416e-ba16-90a035835a88",
   "metadata": {},
   "source": [
    "We'll need the Chinook sample DB for this example. There's many places to download it from, e.g. https://database.guide/2-sample-databases-sqlite/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2ccca6fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = SQLDatabase.from_uri(\"sqlite:///./Chinook.db\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "05ba88ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_schema(_):\n",
    "    return db.get_table_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a4eda902",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_query(query):\n",
    "    return db.run(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5046cb17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnableLambda, RunnableMap\n",
    "\n",
    "model = ChatOpenAI()\n",
    "\n",
    "inputs = {\n",
    "    \"schema\": RunnableLambda(get_schema),\n",
    "    \"question\": itemgetter(\"question\")\n",
    "}\n",
    "sql_response = (\n",
    "        RunnableMap(inputs)\n",
    "        | prompt\n",
    "        | model.bind(stop=[\"\\nSQLResult:\"])\n",
    "        | StrOutputParser()\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a5552039",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'SELECT COUNT(*) FROM Employee'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_response.invoke({\"question\": \"How many employees are there?\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d6fee130",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n",
    "{schema}\n",
    "\n",
    "Question: {question}\n",
    "SQL Query: {query}\n",
    "SQL Response: {response}\"\"\"\n",
    "prompt_response = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "923aa634",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_chain = (\n",
    "    RunnableMap({\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"query\": sql_response,\n",
    "    }) \n",
    "    | {\n",
    "        \"schema\": RunnableLambda(get_schema),\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"query\": itemgetter(\"query\"),\n",
    "        \"response\": lambda x: db.run(x[\"query\"])    \n",
    "    } \n",
    "    | prompt_response \n",
    "    | model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e94963d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='There are 8 employees.', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_chain.invoke({\"question\": \"How many employees are there?\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f358d7b-a721-4db3-9f92-f06913428afc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
